# BSD 2-CLAUSE LICENSE

# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:

# Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# original author: Kaixu Yang
"""Automatically infers seasonality Fourier series orders."""

import copy
from dataclasses import dataclass
from enum import Enum
from typing import List
from typing import Optional

import numpy as np
import pandas as pd
from plotly import graph_objs as go
from plotly.colors import DEFAULT_PLOTLY_COLORS
from plotly.subplots import make_subplots
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import RidgeCV
from sklearn.linear_model import SGDRegressor

from greykite.common.constants import TIME_COL
from greykite.common.constants import VALUE_COL
from greykite.common.constants import TimeFeaturesEnum
from greykite.common.features.timeseries_features import add_time_features_df
from greykite.common.features.timeseries_features import fourier_series_multi_fcn
from greykite.common.features.timeseries_features import get_default_origin_for_time_vars


class TrendAdjustMethodEnum(Enum):
    """Trend adjustment methods accepted by
    `~greykite.algo.common.seasonality_inferrer.SeasonalityInferrer.infer_fourier_series_order`.

    Every member's value is identical to its name, so configs refer to a
    method by its plain string name (e.g. ``"spline_fit"``).
    """
    #: Computes the mean within each seasonal period and subtracts it.
    seasonal_average = "seasonal_average"
    #: Computes the mean of the entire time series and subtracts it.
    overall_average = "overall_average"
    #: Fits a knot-free spline (i.e. a polynomial) of a given degree and subtracts it.
    spline_fit = "spline_fit"
    #: Leaves the trend in place.
    none = "none"


@dataclass
class SeasonalityInferConfig:
    """A dataclass to pass the parameters for
    `~greykite.algo.common.seasonality_inferrer.SeasonalityInferrer.infer_fourier_series_order`.

    Attributes marked "Required" have no default. For the optional attributes,
    a None value is replaced with the documented default before inference, and a
    non-None value passed directly to
    `~greykite.algo.common.seasonality_inferrer.SeasonalityInferrer.infer_fourier_series_order`
    for the same parameter overrides the value set here.

    Attributes
    ----------
    seas_name : `str`
        Required.
        The seasonality component name.
        Used as the key of the inferred best orders, so it should be unique across configs.
    col_name : `str`
        Required.
        The column name used to generate the seasonality Fourier series.
        Must be in ``df`` or be a time feature that can be generated by
        `~greykite.common.features.timeseries_features.build_time_features_df`.
        See `~greykite.common.features.timeseries_features.fourier_series_multi_fcn`.
    period : `float`
        Required.
        The period corresponding to ``col_name``.
        See `~greykite.common.features.timeseries_features.fourier_series_multi_fcn`.
    max_order : `int`
        Required.
        The maximum Fourier series order to fit. Must be a positive integer.
    adjust_trend_method : `str` or None, default "seasonal_average"
        The method used to adjust the trend.
        Must be a member name of `TrendAdjustMethodEnum`.
        None is replaced with "seasonal_average", which by default subtracts the yearly average.
    adjust_trend_param : `dict` or None, default None
        Additional keyword arguments for adjusting the trend.
        For valid options, see
        `~greykite.algo.common.seasonality_inferrer.SeasonalityInferrer._adjust_trend`.
    fit_algorithm : `str` or None, default "ridge"
        The algorithm used to fit the seasonality.
        Supported algorithms are "linear", "ridge" and "sgd".
        None is replaced with "ridge".
    tolerance : `float` or None, default 0.0
        A tolerance on the criterion to allow a smaller order.
        For example, if AIC's minimum is 100 and ``tolerance`` is 0.1,
        then the function will find the smallest order that has AIC less than or equal to 110.
        None is replaced with 0.0.
    plotting : `bool` or None, default False
        Whether to generate plots.
        If True, the returned dictionary will have a plot under the "fig" key.
        Can be turned off to speed up the process.
        None is replaced with False.
    aggregation_period : `str` or None, default None
        The period used to aggregate (by mean) the trend-adjusted series before fitting
        the Fourier series; passed as the rule to `pandas.DataFrame.resample`.
        Aggregating to eliminate shorter seasonal periods may help get more accurate orders,
        but make sure enough observations remain after aggregation
        (at least ``2 * max_order + 1`` for a unique regression solution).
        None means no aggregation.
    offset : `int` or None, default 0
        The offset to be added to the inferred order.
        The order after adding the offset cannot be negative.
        None is replaced with 0.
    criterion : `str` or None, default "bic"
        The criterion used to pick the most appropriate order.
        Supported criteria are "aic" and "bic".
        None is replaced with "bic".
    """
    # Required fields (no defaults).
    seas_name: str
    col_name: str
    period: float
    max_order: int
    # Optional fields. None values are filled with defaults by the inferrer.
    # NOTE: field order defines the positional constructor signature; do not reorder.
    adjust_trend_method: Optional[str] = TrendAdjustMethodEnum.seasonal_average.name
    adjust_trend_param: Optional[dict] = None
    fit_algorithm: Optional[str] = "ridge"
    tolerance: Optional[float] = 0.0
    plotting: Optional[bool] = False
    aggregation_period: Optional[str] = None
    offset: Optional[int] = 0
    criterion: Optional[str] = "bic"


class SeasonalityInferrer:
    """A class to infer appropriate Fourier series orders in different
    seasonality components.

    The method allows users to:

        - optionally remove the trend with different methods.
          Available methods are in `TrendAdjustMethodEnum`.
        - optionally do an aggregation.
        - fits the seasonality component with different Fourier series orders.
        - calculates the AIC/BIC of the fits.
        - choose the most appropriate order with AIC or BIC and an optional tolerance.
        - plot the investigations.

    Attributes
    ----------
    df : `pandas.DataFrame` or None
        The input timeseries.
    time_col : `str` or None
        The column name for timestamps in ``df``.
    value_col : `str` or None
        The column name for values in ``df``.
    fourier_series_orders : `dict` or None
        The inferred Fourier series orders.
        The keys are the seasonality component names.
        The values are the inferred best orders according to the config.
    df_features : `pandas.DataFrame` or None
        The cached dataframe with time features.
        Building this df is slow for large dataset.
        We cache it the first time we build it for subsequent uses.
    """

    def __init__(self):
        """Creates an inferrer whose state is entirely unset.

        Every attribute starts as None and is populated while
        ``infer_fourier_series_order`` runs.
        """
        # Results and caches produced by the inferrer.
        self.fourier_series_orders: Optional[dict] = None
        self.df_features: Optional[pd.DataFrame] = None
        # Inputs recorded (after preprocessing) by ``infer_fourier_series_order``.
        self.df: Optional[pd.DataFrame] = None
        self.time_col: Optional[str] = None
        self.value_col: Optional[str] = None

    # Name of the column that ``_adjust_trend`` adds to hold the fitted trend,
    # which is then subtracted from the value column.
    FITTED_TREND_COL = "FITTED_TREND"

    def infer_fourier_series_order(
            self,
            df: pd.DataFrame,
            configs: List[SeasonalityInferConfig],
            time_col: str = TIME_COL,
            value_col: str = VALUE_COL,
            adjust_trend_method: Optional[str] = None,
            adjust_trend_param: Optional[dict] = None,
            fit_algorithm: Optional[str] = None,
            tolerance: Optional[float] = None,
            plotting: Optional[bool] = None,
            aggregation_period: Optional[str] = None,
            offset: Optional[int] = None,
            criterion: Optional[str] = None) -> dict:
        """Infers the most appropriate Fourier series order.

        Can infer multiple seasonality components with multiple configs at the same time.
        The configuration for each component is passed as a list of
        `~greykite.algo.common.seasonality_inferrer.SeasonalityInferConfig` objects.
        To override a parameter for all configs, pass it via this function's parameter.

        For each seasonality component,
        the method first does an optional trend removal via grouped average or spline fit.
        For example, for yearly seasonality, one option is to remove the average of each year
        from the time series.
        The seasonality pattern is clearer and dominates after the trend removal.

        Next it does an optional aggregation to emphasize the current seasonality.
        For example, for yearly seasonality, it can do a weekly aggregation so that
        the weekly seasonality won't be mixed in when modeling yearly seasonality.

        Then it fits seasonality models using Fourier series with orders up to ``max_order``,
        and computes the AIC/BIC of the models.

        The final order is selected based on the criterion with a tolerance adjustment.
        A pre-specified offset can also be added to the selected order.

        Parameters
        ----------
        df : `pandas.DataFrame`
            The input timeseries. It is copied; the caller's dataframe is not modified.
        configs : `list` [`SeasonalityInferConfig`]
            A list of `~greykite.algo.common.seasonality_inferrer.SeasonalityInferConfig`
            objects. Each element corresponds to the config for a seasonality component.
            For example, if you would like to infer seasonality orders for yearly seasonality and
            weekly seasonality, you need to provide a list of two configs.
        time_col : `str`
            The column name for timestamps in ``df``.
            Converted with `pandas.to_datetime`.
        value_col : `str`
            The column name for values in ``df``.
        adjust_trend_method : `str` or None, default None
            The method used to adjust the trend.
            Must be a member name of `TrendAdjustMethodEnum`.
            If not None, value is used to override all configs.
        adjust_trend_param : `dict` or None, default None
            Additional keyword arguments for adjusting the trend.
            For valid options, see
            `~greykite.algo.common.seasonality_inferrer.SeasonalityInferrer._adjust_trend`.
            If not None, value is used to override all configs.
        fit_algorithm : `str` or None, default None
            The algorithm used to fit the seasonality.
            Supported algorithms are "linear", "ridge" and "sgd".
            If not None, value is used to override all configs.
        tolerance : `float` or None, default None
            A tolerance on the criterion to allow a smaller order.
            For example, if AIC's minimum is 100 and ``tolerance`` is 0.1,
            then the function will find the smallest order that has AIC less than or equal to 110.
            If not None, value is used to override all configs.
        plotting : `bool` or None, default None
            Whether to generate plots.
            If True, the returned dictionary will have a plot under the "fig" key.
            Can be turned off to speed up the process.
            If not None, value is used to override all configs.
        aggregation_period : `str` or None, default None
            The period used to aggregate (by mean) before fitting the Fourier series.
            Aggregating to eliminate shorter seasonal periods may help get more accurate orders,
            but make sure enough observations remain after aggregation
            (at least ``2 * max_order + 1`` for a unique regression solution).
            If not None, value is used to override all configs.
        offset : `int` or None, default None
            The offset to be added to the inferred orders.
            The orders after applying offsets cannot be negative.
            If not None, value is used to override all configs.
        criterion : `str` or None, default None
            The criterion used to pick the most appropriate orders.
            Supported criteria are "aic" and "bic".
            If not None, value is used to override all configs.

        Returns
        -------
        result : `dict`
            The result dictionary with the following keys:

                - "result": a list of result dictionaries, one per config, in config order.
                  The keys are:

                    - "seas_name": the seasonality name.
                    - "orders": the Fourier series orders fitted.
                    - "aics": the fitted AICs.
                    - "bics": the fitted BICs.
                    - "best_aic_order": the order corresponding to the best feasible AIC.
                    - "best_bic_order": the order corresponding to the best feasible BIC.
                    - "fig": the diagnostic figure.

                - "best_orders": a dictionary mapping seasonality component names to their
                  inferred Fourier series orders (also stored in ``self.fourier_series_orders``).

        Raises
        ------
        ValueError
            If a parameter (passed here or in a config) is not among its allowed values.
        """
        # Checks and processes the input parameters.
        configs = self._process_params(
            configs=configs,
            adjust_trend_method=adjust_trend_method,
            adjust_trend_param=adjust_trend_param,
            fit_algorithm=fit_algorithm,
            tolerance=tolerance,
            plotting=plotting,
            aggregation_period=aggregation_period,
            offset=offset,
            criterion=criterion
        )
        # Stores a copy of the input with parsed timestamps; the caller's ``df`` is untouched.
        self.df = df.copy()
        self.df[time_col] = pd.to_datetime(self.df[time_col])
        self.time_col = time_col
        self.value_col = value_col
        # ``df_features`` caches time features built from ``self.df``,
        # so it must be reset whenever a new ``df`` is given.
        self.df_features = None

        # Infers the Fourier series order for each component in turn.
        result = []
        for config in configs:
            # Adjusts the trend and does the optional aggregation.
            df_adj = self._process_df(
                df=self.df,
                time_col=time_col,
                value_col=value_col,
                adjust_trend_method=config.adjust_trend_method,
                adjust_trend_params=config.adjust_trend_param,
                aggregation_period=config.aggregation_period
            )
            # Fits Fourier series up to ``max_order`` and computes the fitting metrics.
            seasonality_result = self._gets_seasonality_fitting_metrics(
                df=df_adj,
                time_col=time_col,
                value_col=value_col,
                seas_name=config.seas_name,
                col_name=config.col_name,
                period=config.period,
                max_order=config.max_order,
                fit_algorithm=config.fit_algorithm,
                plotting=config.plotting,
                tolerance=config.tolerance,
                offset=config.offset
            )
            result.append(seasonality_result)

        # Picks each component's best order according to its own criterion.
        # A later config overwrites an earlier one that has the same ``seas_name``.
        best_orders = {
            config.seas_name: (
                seasonality_result["best_aic_order"]
                if config.criterion == "aic"
                else seasonality_result["best_bic_order"]
            )
            for config, seasonality_result in zip(configs, result)
        }
        self.fourier_series_orders = best_orders

        return {
            "result": result,
            "best_orders": best_orders
        }

    def _process_params(
            self,
            configs: List[SeasonalityInferConfig],
            adjust_trend_method: Optional[str],
            adjust_trend_param: Optional[dict],
            fit_algorithm: Optional[str],
            tolerance: Optional[float],
            plotting: Optional[bool],
            aggregation_period: Optional[str],
            offset: Optional[int],
            criterion: Optional[str]) -> List[SeasonalityInferConfig]:
        """Checks and overrides the input parameters, and makes sure the values are valid.

        The configs are shallow-copied first, so the caller's
        `SeasonalityInferConfig` objects are left unchanged.

        Parameters
        ----------
        configs : `list` [`SeasonalityInferConfig`]
            A list of `~greykite.algo.common.seasonality_inferrer.SeasonalityInferConfig`
            objects. Each element corresponds to the config for a seasonality component.
            For example, if you would like to infer seasonality orders for yearly seasonality and
            weekly seasonality, you need to provide a list of two configs.
        adjust_trend_method : `str` or None, default None
            The method used to adjust the trend.
            Must be a member name of `TrendAdjustMethodEnum`.
            If not None, value is used to override all configs.
        adjust_trend_param : `dict` or None, default None
            Additional keyword arguments for adjusting the trend.
            For valid options, see
            `~greykite.algo.common.seasonality_inferrer.SeasonalityInferrer._adjust_trend`.
            If not None, value is used to override all configs.
        fit_algorithm : `str` or None, default None
            The algorithm used to fit the seasonality.
            Supported algorithms are "linear", "ridge" and "sgd".
            If not None, value is used to override all configs.
        tolerance : `float` or None, default None
            A tolerance on the criterion to allow a smaller order.
            For example, if AIC's minimum is 100 and ``tolerance`` is 0.1,
            then the function will find the smallest order that has AIC less than or equal to 110.
            If not None, value is used to override all configs.
        plotting : `bool` or None, default None
            Whether to generate plots.
            If True, the returned dictionary will have a plot under the "fig" key.
            Can be turned off to speed up the process.
            If not None, value is used to override all configs.
        aggregation_period : `str` or None, default None
            The period used to aggregate (by mean) before fitting the Fourier series.
            Make sure enough observations remain after aggregation
            (at least ``2 * max_order + 1`` for a unique regression solution).
            If not None, value is used to override all configs.
        offset : `int` or None, default None
            The offset to be added to the inferred orders.
            The orders after applying offsets cannot be negative.
            If not None, value is used to override all configs.
        criterion : `str` or None, default None
            The criterion used to pick the most appropriate orders.
            Supported criteria are "aic" and "bic".
            If not None, value is used to override all configs.

        Returns
        -------
        configs : `list` [`SeasonalityInferConfig`]
            Shallow copies of the input configs, with overrides applied, None values
            replaced by defaults, and values checked against the allowed values.

        Raises
        ------
        ValueError
            If an override value or a config value is not among the allowed values.
        """
        # ``_apply_default_value`` writes with ``setattr``; working on shallow copies
        # keeps overrides and filled-in defaults from leaking into the caller's objects.
        configs = [copy.copy(config) for config in configs]
        trend_methods = [member.name for member in TrendAdjustMethodEnum]
        # For each optional parameter: (name, override value, default value, allowed values).
        # ``None`` as allowed values skips the validity check.
        param_specs = (
            ("adjust_trend_method", adjust_trend_method,
             TrendAdjustMethodEnum.seasonal_average.name, trend_methods),
            ("adjust_trend_param", adjust_trend_param, None, None),
            ("fit_algorithm", fit_algorithm, "ridge", ["ridge", "linear", "sgd"]),
            ("tolerance", tolerance, 0.0, None),
            ("plotting", plotting, False, [True, False]),
            ("aggregation_period", aggregation_period, None, None),
            ("offset", offset, 0, None),
            ("criterion", criterion, "bic", ["aic", "bic"]),
        )
        for param_name, override_value, default_value, allowed_values in param_specs:
            configs = self._apply_default_value(
                configs=configs,
                param_name=param_name,
                override_value=override_value,
                default_value=default_value,
                allowed_values=allowed_values
            )
        return configs

    @staticmethod
    def _apply_default_value(
            configs: List[SeasonalityInferConfig],
            param_name: str,
            override_value: object,
            default_value: object,
            allowed_values: Optional[List[object]] = None) -> List[SeasonalityInferConfig]:
        """Resolves one parameter across every config in ``configs``.

        For the attribute ``param_name`` of each config:

            - if ``override_value`` is not None, it is validated and then assigned to every config;
            - otherwise, None values are replaced with ``default_value`` and the remaining
              values are checked against ``allowed_values`` (if the latter is not None).

        The config objects are modified in place and the same list is returned.

        Parameters
        ----------
        configs : `list` [`SeasonalityInferConfig`]
            A list of `~greykite.algo.common.seasonality_inferrer.SeasonalityInferConfig`
            objects. Each element corresponds to the config for a seasonality component.
            For example, if you would like to infer seasonality orders for yearly seasonality and
            weekly seasonality, you need to provide a list of two configs.
        param_name : `str`
            The name of the config attribute to resolve.
            Also used in error messages.
        override_value : `object` or None
            If not None, the value assigned to ``param_name`` in every config.
        default_value : `object`
            The value that replaces None when ``override_value`` is None.
            It is not checked against ``allowed_values``.
        allowed_values : `list` or None, default None
            The valid values; membership is tested with ``in``.
            If any given value is not valid, an error will be raised.
            Leave as None to skip checking.

        Returns
        -------
        configs : `list` [`SeasonalityInferConfig`]
            The same list of configs, with the parameter overridden or defaulted as needed.

        Raises
        ------
        ValueError
            If ``override_value`` is given and not in ``allowed_values``, or if
            ``override_value`` is None and a config's non-None value is not in ``allowed_values``.
        """
        def is_invalid(value: object) -> bool:
            """Returns True when ``value`` fails the (optional) allowed-values check."""
            return allowed_values is not None and value not in allowed_values

        # An explicit override is validated once and then wins for every config.
        if override_value is not None:
            if is_invalid(override_value):
                raise ValueError(f"The parameter '{param_name}' has value {override_value}, "
                                 f"which is not valid. Valid values are '{allowed_values}'.")
            for config in configs:
                setattr(config, param_name, override_value)
            return configs

        # Without an override, fills None with the default and validates explicit values.
        for config in configs:
            param = getattr(config, param_name)
            if param is None:
                setattr(config, param_name, default_value)
            elif is_invalid(param):
                raise ValueError(f"The parameter '{param_name}' in SeasonalityInferConfig has value {param}, "
                                 f"which is not valid. Valid values are '{allowed_values}'.")
        return configs

    def _adjust_trend(
            self,
            df: pd.DataFrame,
            time_col: str = TIME_COL,
            value_col: str = VALUE_COL,
            method: str = TrendAdjustMethodEnum.seasonal_average.name,
            trend_average_col: str = "year",
            spline_fit_degree: int = 1) -> pd.DataFrame:
        """Adjusts the time series by removing its trend.

        The methods to remove the trend are the members of
        `~greykite.algo.common.seasonality_inferrer.TrendAdjustMethodEnum`:

            - "seasonal_average" : subtracts the average of ``value_col`` within each
              group of ``trend_average_col`` (e.g. each year).
            - "overall_average" : subtracts the overall average.
            - "spline_fit" : fits the trend using a spline with no knots up to a certain
              degree (a polynomial in continuous time) and subtracts the fit.
            - "none" : subtracts nothing; the fitted trend column is all zeros.

        Parameters
        ----------
        df : `pandas.DataFrame`
            The input time series. It is not modified.
        time_col : `str`, default ``TIME_COL``
            The column name for timestamps in ``df``.
        value_col : `str`, default ``VALUE_COL``
            The column name for values in ``df``.
        method : `str`, default `~greykite.algo.common.seasonality_inferrer.TrendAdjustMethodEnum.seasonal_average.name`
            The adjustment method. Must be a member name of
            `~greykite.algo.common.seasonality_inferrer.TrendAdjustMethodEnum`.
        trend_average_col : `str`, default "year"
            The column name that specifies the categories used to calculate the seasonal mean.
            Suggested columns are "year", "year_quarter", "year_month", etc.
            Has no effect if ``method`` is not
            `~greykite.algo.common.seasonality_inferrer.TrendAdjustMethodEnum.seasonal_average.name`.
            Must be provided in ``df`` or be generated by
            `~greykite.common.features.timeseries_features.build_time_features_df`.
        spline_fit_degree : `int`, default 1
            The degree of the spline used to fit the trend. Must be at least 1.
            Has no effect if ``method`` is not
            `~greykite.algo.common.seasonality_inferrer.TrendAdjustMethodEnum.spline_fit.name`.

        Returns
        -------
        df : `pandas.DataFrame`
            A copy of the original df with the fitted trend subtracted from ``value_col``
            and an extra column ``self.FITTED_TREND_COL`` holding the fitted trend.

        Raises
        ------
        ValueError
            If ``method`` is not a valid name; if ``trend_average_col`` is neither in ``df``
            nor a time feature name (for "seasonal_average"); or if ``spline_fit_degree`` < 1
            (for "spline_fit").
        """
        # Checks if the method is valid.
        valid_methods = [member.name for member in TrendAdjustMethodEnum]
        if method not in valid_methods:
            raise ValueError(f"The trend adjust method '{method}' is not a valid name. "
                             f"Available methods are {valid_methods}.")
        df = df.copy()

        if method == TrendAdjustMethodEnum.overall_average.name:
            # Subtracts the overall time series average.
            df[self.FITTED_TREND_COL] = df[value_col].mean()

        elif method == TrendAdjustMethodEnum.seasonal_average.name:
            # Subtracts the seasonal average within each seasonal period.
            df_cols = list(df.columns)
            if trend_average_col not in df.columns:
                if trend_average_col in TimeFeaturesEnum.__dict__["_member_names_"]:
                    # Building time features is slow, so it is done once and cached.
                    # NOTE: the cache is reused regardless of ``df``; this assumes ``df`` is the
                    # frame the cache was built from, which holds when called through
                    # ``infer_fourier_series_order`` (it resets the cache for every new input).
                    if self.df_features is None:
                        df = add_time_features_df(
                            df=df,
                            time_col=time_col,
                            conti_year_origin=get_default_origin_for_time_vars(
                                df=df,
                                time_col=time_col
                            )
                        )
                        self.df_features = df.copy()
                    else:
                        df = self.df_features.copy()
                else:
                    raise ValueError(f"The trend_average_col '{trend_average_col}' is neither found in df "
                                     f"'{list(df.columns)}' nor in the built time features "
                                     f"'{TimeFeaturesEnum.__dict__['_member_names_']}'.")
            df_seasonal_average = (df[[value_col, trend_average_col]]
                                   .groupby(trend_average_col)
                                   .mean()
                                   .reset_index(drop=False)
                                   .rename(columns={value_col: self.FITTED_TREND_COL}))
            df = df.merge(df_seasonal_average, on=trend_average_col)
            # Drops the helper time features, keeping only the original columns plus the trend.
            df = df[df_cols + [self.FITTED_TREND_COL]]

        elif method == TrendAdjustMethodEnum.spline_fit.name:
            # Subtracts a polynomial trend fitted on continuous time.
            if spline_fit_degree < 1:
                raise ValueError(f"Spline degree has be to a positive integer, "
                                 f"but found {spline_fit_degree}.")
            df_features = add_time_features_df(
                df=df,
                time_col=time_col,
                conti_year_origin=get_default_origin_for_time_vars(
                    df=df,
                    time_col=time_col
                )
            )
            # ``.copy()`` makes ``df_fit`` an independent frame, so adding the
            # higher-degree columns below does not write into a slice of ``df_features``.
            df_fit = df_features[[value_col, TimeFeaturesEnum.ct1.name]].copy()
            fit_cols = [TimeFeaturesEnum.ct1.name]
            for degree in range(2, int(spline_fit_degree) + 1):
                df_fit[f"ct{degree}"] = df_fit[TimeFeaturesEnum.ct1.name] ** degree
                fit_cols.append(f"ct{degree}")
            # NOTE(review): SGDRegressor is used with default settings (no fixed random_state,
            # unscaled features), so the fitted trend may vary between runs — confirm intended.
            model = SGDRegressor()
            model.fit(df_fit[fit_cols], df_fit[value_col])
            df[self.FITTED_TREND_COL] = model.predict(df_fit[fit_cols])

        elif method == TrendAdjustMethodEnum.none.name:
            # No adjustment: a zero trend still populates the trend column,
            # so every method returns the same schema and the subtraction below works.
            df[self.FITTED_TREND_COL] = 0.0

        df[value_col] -= df[self.FITTED_TREND_COL]
        return df

    def _process_df(
            self,
            df: pd.DataFrame,
            time_col: str,
            value_col: str,
            adjust_trend_method: Optional[str],
            adjust_trend_params: Optional[dict],
            aggregation_period: Optional[str]) -> pd.DataFrame:
        """Prepares ``df`` for seasonality fitting.

        Two optional steps are applied in order, each skipped when its argument is None:

                - trend removal via ``self._adjust_trend``
                - mean aggregation over ``aggregation_period`` via `pandas.DataFrame.resample`

        Parameters
        ----------
        df : `pandas.DataFrame`
            The input time series. It is not modified.
        time_col : `str`
            The column name for timestamps in ``df``; also the resampling key.
        value_col : `str`
            The column name for values in ``df``.
        adjust_trend_method : `str` or None
            The adjustment method. Must be a member name of
            `~greykite.algo.common.seasonality_inferrer.TrendAdjustMethodEnum`.
            None skips the trend adjustment.
        adjust_trend_params : `dict` or None
            Extra keyword arguments forwarded to ``self._adjust_trend``.
            None is treated as no extra arguments.
        aggregation_period : `str` or None
            The resampling rule applied after the trend adjustment.
            None skips the aggregation.

        Returns
        -------
        df_adj : `pandas.DataFrame`
            A new dataframe with the trend removed and/or aggregated.
        """
        if adjust_trend_method is None:
            processed = df.copy()
        else:
            extra_kwargs = {} if adjust_trend_params is None else adjust_trend_params
            processed = self._adjust_trend(
                df=df,
                time_col=time_col,
                value_col=value_col,
                method=adjust_trend_method,
                **extra_kwargs
            )

        if aggregation_period is None:
            return processed
        # Averages within each period and turns ``time_col`` back into a column.
        resampled = processed.resample(aggregation_period, on=time_col).mean()
        return resampled.reset_index(drop=False)

    def _gets_seasonality_fitting_metrics(
            self,
            df: pd.DataFrame,
            time_col: str,
            value_col: str,
            seas_name: str,
            col_name: str,
            period: float,
            max_order: int,
            fit_algorithm: str = "ridge",
            plotting: bool = False,
            tolerance: float = 0.0,
            offset: int = 0) -> dict:
        """Fits the seasonality with Fourier series and finds the best order according to a criterion.

        For every order from 1 to ``max_order``, a model of the chosen ``fit_algorithm``
        is fitted on the Fourier features up to that order. The AIC and BIC are then computed
        from the in-sample mean squared error, where the number of parameters is ``2 * order``.

        Parameters
        ----------
        df : `pandas.DataFrame`
            The input time series after trend adjustment.
        time_col : `str`
            The column name for timestamps in ``df``.
        value_col : `str`
            The column name for values in ``df``.
        seas_name : `str`
            The name for the seasonality component.
        col_name : `str`
            The column name for the column used to build Fourier series.
            Must be in ``df`` or can be generated by
            `~greykite.common.features.timeseries_features.build_time_features_df`.
            See `~greykite.common.features.timeseries_features.fourier_series_multi_func`.
        period : `float`
            The period corresponding to ``col_name``.
            See `~greykite.common.features.timeseries_features.fourier_series_multi_func`.
        max_order : `int`
            The maximum Fourier series order to fit.
            Must be a positive integer; integral floats such as ``3.0`` are converted to `int`.
        fit_algorithm : `str`, default "ridge"
            The algorithm used to fit the seasonality.
            Supported algorithms are "linear", "ridge" and "sgd".
        plotting : `bool`, default False
            Whether to generate plots.
            If True, the returned dictionary will have plot via the "fig" key.
            Can turn this off to speed up the process.
        tolerance : `float`, default 0.0
            A tolerance on the criterion to allow a smaller order.
            For example, if AIC's minimum is 100 and ``tolerance`` is 0.1,
            then the function will find the smallest order that has AIC less than or equal to 110.
        offset : `int`, default 0
            The offset to be added to the best orders.
            The resulting orders are clipped at 0.

        Returns
        -------
        result : `dict`
            A result dictionary with the following keys:

                - "seas_name": the seasonality name.
                - "orders": the Fourier series orders fitted.
                - "aics": the fitted AICs.
                - "bics": the fitted BICs.
                - "best_aic_order": the order corresponding to the best feasible AIC.
                - "best_bic_order": the order corresponding to the best feasible BIC.
                - "fig": the diagnostic figure, or None if ``plotting`` is False.

        Raises
        ------
        ValueError
            If ``col_name`` is neither a column of ``df`` nor a time feature name,
            if ``max_order`` is not a positive integer,
            if ``fit_algorithm`` is not supported,
            if ``tolerance`` is negative,
            or if no rows remain after dropping missing values.
        """
        # Checks inputs.
        if col_name not in df.columns and col_name not in TimeFeaturesEnum.__members__:
            raise ValueError(f"The column name '{col_name}' is not found in ``df`` or time features.")
        if max_order <= 0 or int(max_order) != max_order:
            raise ValueError(f"The max order must be a positive integer, found {max_order}.")
        # Integral floats (e.g. 3.0) pass the check above but would break ``range`` below.
        max_order = int(max_order)
        models = {
            "ridge": RidgeCV,
            "linear": LinearRegression,
            "sgd": SGDRegressor
        }
        if fit_algorithm not in models:
            raise ValueError(f"The fit_algorithm '{fit_algorithm}' is not supported. "
                             f"Must be one of {list(models.keys())}.")
        model_class = models[fit_algorithm]

        if tolerance < 0:
            raise ValueError(f"The tolerance percentage must be non-negative, found {tolerance}.")

        # Generates Fourier series.
        fourier_func = fourier_series_multi_fcn(
            col_names=[col_name],
            periods=[period],
            orders=[max_order],
            seas_names=[seas_name]
        )
        if col_name not in df.columns:
            df = add_time_features_df(
                df=df,
                time_col=time_col,
                conti_year_origin=get_default_origin_for_time_vars(
                    df=df,
                    time_col=time_col
                )
            )
        seasonality_df = fourier_func(df)["df"]
        # Both parts are reset to positional indices so they line up row by row;
        # after ``dropna`` the surviving index values are positions into ``df``.
        df_features = pd.concat([
            df[[value_col]].reset_index(drop=True),
            seasonality_df.reset_index(drop=True)
        ], axis=1)
        df_features.dropna(inplace=True)  # In case ``df_features`` has NANs and fit fails.
        if df_features.empty:
            raise ValueError(
                "No rows are left to fit the seasonality after dropping missing values.")

        # The feature columns and their Fourier orders do not depend on the order
        # being fitted, so they are parsed once here.
        # After split, the 1st element is sin{order} or cos{order};
        # the characters after the 3-letter prefix give the order.
        feature_cols = [col for col in df_features.columns if col != value_col]
        feature_orders = {col: int(col.split("_")[0][3:]) for col in feature_cols}
        n_obs = len(df_features)
        y = df_features[value_col]

        # Fits the models and calculates the metrics.
        # Defines a function to fit seasonality model for a single order.
        # This is later used with the `map` method to fast calculate the AIC/BICs for
        # all orders.
        # See https://wiki.python.org/moin/PythonSpeed/PerformanceTips#Loops
        def fit_seasonality_model(order: int) -> tuple:
            """Fits the seasonality model for a single order and computes its AIC/BIC.

            Parameters
            ----------
            order : `int`
                The Fourier series order.

            Returns
            -------
            result : `tuple`
                (AIC, BIC) score for the model.
            """
            model = model_class()
            features = [col for col in feature_cols if feature_orders[col] <= order]
            model.fit(df_features[features], y)
            pred = model.predict(df_features[features])
            mse = np.mean((y - pred) ** 2)
            # The number of features in the current model is "2 * order".
            # In the i.i.d. Gaussian case, the -2 * loglikelihood reduces to n * log(MSE).
            aic = 2 * 2 * order + n_obs * np.log(mse)
            bic = np.log(n_obs) * 2 * order + n_obs * np.log(mse)
            return aic, bic

        # Uses `map` to fast calculate the AIC/BICs.
        orders = list(range(1, max_order + 1))
        abics = list(map(fit_seasonality_model, orders))
        aics = [abic[0] for abic in abics]
        bics = [abic[1] for abic in abics]

        min_aic = min(aics)
        min_bic = min(bics)

        # If a tolerance is allowed, the search will sacrifice some criterion
        # to use a smaller order of Fourier series.
        if tolerance > 0:
            aic_tol = min_aic + abs(min_aic) * tolerance
            bic_tol = min_bic + abs(min_bic) * tolerance
            # Eligible criteria are those that are less than or equal to the tolerance.
            # The original minimums are guaranteed to be in the lists.
            eligible_aics = [aic for aic in aics if aic <= aic_tol]
            eligible_bics = [bic for bic in bics if bic <= bic_tol]
            # The first element is guaranteed to have order no greater than the original minimum.
            min_aic = eligible_aics[0]
            min_bic = eligible_bics[0]

        # ``list.index`` returns the first (i.e. smallest-order) match.
        best_aic_order = max(orders[aics.index(min_aic)] + offset, 0)
        best_bic_order = max(orders[bics.index(min_bic)] + offset, 0)

        # Gets the optional plot.
        fig = None
        if plotting:
            fig = self._make_plot(
                df=df,
                df_features=df_features,
                model_class=model_class,
                time_col=time_col,
                value_col=value_col,
                orders=orders,
                aics=aics,
                bics=bics,
                best_aic_order=best_aic_order,
                best_bic_order=best_bic_order
            )

        return {
            "seas_name": seas_name,
            "orders": orders,
            "aics": aics,
            "bics": bics,
            "best_aic_order": best_aic_order,
            "best_bic_order": best_bic_order,
            "fig": fig
        }

    def _make_plot(
            self,
            df: pd.DataFrame,
            df_features: pd.DataFrame,
            model_class: any,
            time_col: str,
            value_col: str,
            orders: List[int],
            aics: List[float],
            bics: List[float],
            best_aic_order: int,
            best_bic_order: int) -> go.Figure:
        """Makes a figure that contains 2-3 subplots.
        The subplots are:

            - AIC/BIC vs. Fourier series orders.
            - De-trended and aggregated time series vs fitted curves with best AIC/BICs.
            - (Optional) the original time series, the fitted trend and the de-trended and aggregated
              time series.

        Parameters
        ----------
        df : `pandas.DataFrame`
            The input time series after de-trend and aggregation.
        df_features : `pandas.DataFrame`
            The time series with seasonality features generated.
            Its index is expected to hold row positions into ``df``
            (it may skip positions whose rows were dropped for missing values).
        model_class : `any`
            The class type for the model to fit the seasonality curve.
        time_col : `str`
            The column name for timestamps in ``df``.
        value_col : `str`
            The column name for values in ``df``.
        orders : `list` [`int`]
            The orders used in calculating ``aics`` and ``bics``.
        aics : `list` [`float`]
            The AICs corresponding to ``orders``.
        bics : `list` [`float`]
            The BICs corresponding to ``orders``.
        best_aic_order : `int`
            The best AIC order.
        best_bic_order : `int`
            The best BIC order.

        Returns
        -------
        fig : `plotly.graph_object.Figure`
            The figure object.
        """
        rows = 2
        subtitles = (
            "Criteria vs. Fourier series orders",
            "Fitted seasonality on best orders"
        )
        # If both original df and trend adjusted df are available,
        # we include the trend adjustment visualization, too.
        if (self.df is not None and time_col in self.df
                and value_col in self.df and self.FITTED_TREND_COL in df):
            rows += 1
            subtitles += (
                "Original time series and detrended/aggregated time series",
            )
        fig = make_subplots(
            rows=rows,
            cols=1,
            subplot_titles=subtitles,
            vertical_spacing=0.1
        )
        # Adds metric plot.
        metrics = {
            "AIC": [aics, best_aic_order, "dash"],
            "BIC": [bics, best_bic_order, None]
        }
        for i, (metric, value) in enumerate(metrics.items()):
            fig.add_trace(
                go.Scatter(
                    x=orders,
                    y=value[0],
                    name=metric,
                    mode="lines",
                    showlegend=True,
                    line=dict(color=DEFAULT_PLOTLY_COLORS[i], dash=value[2])
                ),
                row=1,
                col=1
            )
            # Pins the line to the criteria subplot explicitly instead of relying on
            # the default ``row="all"`` only matching subplots that already have traces.
            fig.add_vline(
                x=value[1],
                name=f"best_{metric}_order",
                line=dict(color=DEFAULT_PLOTLY_COLORS[i], dash=value[2]),
                row=1,
                col=1
            )
        # Adds fit plot.
        fig.add_trace(
            go.Scatter(
                x=df[time_col],
                y=df[value_col],
                name="Timeseries",
                showlegend=True,
                line=dict(color="black")
            ),
            row=2,
            col=1
        )
        # ``df_features`` may have fewer rows than ``df`` when rows with missing values
        # were dropped, so the fitted curves take their timestamps from the positions
        # that survived rather than from the full ``df[time_col]``.
        fit_times = df[time_col].reset_index(drop=True).reindex(df_features.index)
        feature_cols = [col for col in df_features.columns if col != value_col]
        best_orders = {
            "best_AIC_order_fit": best_aic_order,
            "best_BIC_order_fit": best_bic_order
        }
        for i, (name, order) in enumerate(best_orders.items()):
            if order > 0:
                model = model_class()
                # After split, the 1st element is sin{order} or cos{order}.
                # Takes the order to compare with the current order.
                features = [
                    col for col in feature_cols
                    if int(col.split("_")[0][3:]) <= order
                ]
                model.fit(df_features[features], df_features[value_col])
                pred = model.predict(df_features[features])
            else:
                # Order 0 (possible with a negative offset) means no seasonality:
                # the fit is the constant mean.
                pred = np.repeat([df_features[value_col].mean()], len(df_features))
            fig.add_trace(
                go.Scatter(
                    x=fit_times,
                    y=pred,
                    name=name,
                    mode="lines",
                    showlegend=True,
                    line=dict(color=DEFAULT_PLOTLY_COLORS[2 + i])
                ),
                row=2,
                col=1
            )

        # Adds trend adjustment plot.
        if rows == 3:
            fig.add_trace(
                go.Scatter(
                    x=self.df[time_col],
                    y=self.df[value_col],
                    name="Original time series",
                    mode="lines",
                    showlegend=True
                ),
                row=3,
                col=1
            )
            fig.add_trace(
                go.Scatter(
                    x=df[time_col],
                    y=df[value_col],
                    name="Detrend/aggregated time series",
                    mode="lines",
                    showlegend=True,
                    line=dict(color="black")
                ),
                row=3,
                col=1
            )
            fig.add_trace(
                go.Scatter(
                    x=df[time_col],
                    y=df[self.FITTED_TREND_COL],
                    name="Fitted trend",
                    mode="lines",
                    showlegend=True
                ),
                row=3,
                col=1
            )

        # (row, x-axis title, y-axis title) for every subplot present.
        axis_titles = [
            (1, "Fourier series order", "Criteria"),
            (2, "Timestamp", "Timeseries")
        ]
        if rows == 3:
            axis_titles.append((3, "Timestamp", "Timeseries"))
        for row, x_title, y_title in axis_titles:
            fig.update_xaxes(title_text=x_title, row=row, col=1)
            fig.update_yaxes(title_text=y_title, row=row, col=1)
        fig.update_layout(
            title_text="Inferred seasonality Fourier series orders",
            height=1000
        )
        return fig
